"""
/** Copyright 2020 Zhejiang Lab and Zhejiang University. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
"""
from django.db import models
import django
from my_models.models import Models
from my_datasets.models import Datasets
import pickle
import logging
from datetime import datetime
import random
# from functools import reduce
# Logger for this module, registered under an explicit dotted name
# (rather than __name__).
logger_name = 'task.models'
logger = logging.getLogger(name=logger_name)

# Target tasks a Task may request: id -> display label (Chinese).
#   segmentation   -> image segmentation
#   classification -> image classification
#   detection      -> object detection
#   depth          -> depth estimation
#   keypoints      -> keypoint detection
TASK_OPTIONS = dict(
    tasks=dict(
        segmentation="图像分割",
        classification="图像分类",
        detection="物体检测",
        depth="深度估计",
        keypoints="关键点检测",
    ),
)

class GeneralTaskType(models.TextChoices):
    """Kind of job a Task runs; stored as a two-letter code in Task.task_type."""

    # Each member is (stored value, human-readable label).
    PLAINTRAIN = ('PT', 'Plain Train')
    KNOWREORG = ('KA', 'Knowledge Amalgamation')

class Task(models.Model):
    """A user-submitted job (plain training or knowledge amalgamation).

    The job description (``task``), the runtime note (``note``) and the final
    result (``result``) are stored as pickled blobs and decoded on demand by
    :meth:`_getAllInfo`.
    """

    # NOTE(review): default=0 on a unique CharField means a second row saved
    # without an explicit task_uid violates the unique constraint; callers
    # presumably always set it — confirm before relying on the default.
    task_uid = models.CharField(max_length=32, unique=True, default=0)
    distributed = models.BooleanField(default=False)
    task_type = models.CharField(
        max_length=2,
        choices=GeneralTaskType.choices,
        default=GeneralTaskType.KNOWREORG,
    )
    # Pickled task description. Per the Django docs, BinaryField.max_length
    # is enforced only by model validation (full_clean), not by the DB column.
    task = models.BinaryField(max_length=1024)
    # Pickled result object; _getAllInfo reads its ``result`` mapping.
    result = models.BinaryField(max_length=1024, blank=True)
    user = models.ForeignKey('user.User', on_delete=models.CASCADE, related_name='task')
    # NOTE(review): related_name='server' makes the reverse accessor
    # ``<Server>.server`` (the tasks of a server); kept for compatibility.
    server = models.ForeignKey('server.Server', on_delete=models.CASCADE, related_name='server', null=True)
    created_time = models.DateTimeField(auto_now_add=True)
    started_time = models.DateTimeField(null=True)
    completed_time = models.DateTimeField(null=True)
    # Pickled progress note; _getAllInfo reads its ``kwargs`` mapping.
    note = models.BinaryField(max_length=1024, blank=True)

    def __str__(self):
        # Show the uid rather than the pickled ``task`` blob, whose string form
        # is unreadable bytes (or a memoryview repr on some backends).
        return 'Task: {}'.format(self.task_uid)

    @staticmethod
    def _lookup_names(model_cls, refs, name_field):
        """Return the ``name_field`` values of ``model_cls`` rows referenced by ``refs``.

        ``refs`` is a list stored on the pickled task object whose entries are
        indexed by ``"id"``. Returns None when ``refs`` is empty or None.
        The result is an eagerly evaluated list (not a lazy QuerySet), so it
        can be serialized directly; its order is not necessarily that of ``refs``.
        """
        if not refs:
            return None
        ids = [ref["id"] for ref in refs]
        return list(model_cls.objects.filter(id__in=ids).values_list(name_field, flat=True))

    def _getAllInfo(self):
        """Decode the stored blobs into one flat dict describing this task.

        The returned dict contains:
          * ``tasks`` – display labels of the target tasks from TASK_OPTIONS
            (an unknown id is passed through unchanged),
          * ``datasets`` / ``teacher_models`` / ``student_models`` – lists of
            names, or None when the task references none,
          * ``algorithms`` – as stored on the task object,
          * every key of the note's ``kwargs`` when the task has started and a
            note exists, and of the result's ``result`` when it has completed
            and a result exists,
          * ``progress`` – 0 when there is no started note, 1 once a result is
            present, otherwise whatever the note's kwargs supplied,
          * ``id``, ``task_type`` (display label) and the three timestamps.
        """
        # SECURITY: unserialize() is pickle.loads — only safe because these
        # blobs are written by this application; never store client-supplied
        # pickles in these fields.
        task = unserialize(self.task)
        logger.debug(
            "task: tasks=%s datasets=%s teacher_models=%s student_models=%s algorithms=%s",
            task.tasks, task.datasets, task.teacher_models, task.student_models, task.algorithms,
        )

        task_labels = TASK_OPTIONS['tasks']
        res = {
            # Fall back to the raw id so one unknown task type does not make
            # the whole task unreadable (a KeyError previously).
            'tasks': [task_labels.get(t_id, t_id) for t_id in task.tasks],
            'datasets': self._lookup_names(Datasets, task.datasets, "dataset_name"),
            'teacher_models': self._lookup_names(Models, task.teacher_models, "model_name"),
            'student_models': self._lookup_names(Models, task.student_models, "model_name"),
            'algorithms': task.algorithms,
        }

        if self.started_time and self.note:
            # Started: merge the runtime note (may carry its own progress).
            logger.debug("started_time: %s", self.started_time)
            note = unserialize(self.note)
            logger.debug(note.kwargs)
            res.update(note.kwargs)
        else:
            res["progress"] = 0

        if self.completed_time and self.result:
            # Finished: merge the result and force progress to complete.
            logger.debug("completed_time: %s", self.completed_time)
            result = unserialize(self.result)
            logger.debug(result.result)
            res.update(result.result)
            res["progress"] = 1

        # Applied last so note/result entries cannot override these fields.
        res.update({
            'id': self.id,
            # Unlike GeneralTaskType(value).label, this falls back to the raw
            # stored value instead of raising ValueError for an unknown code.
            'task_type': self.get_task_type_display(),
            'created_time': self.created_time,
            'started_time': self.started_time,
            'completed_time': self.completed_time,
        })
        logger.debug(res)
        return res

    @property
    def summary(self):
        """Compact view for listings: _getAllInfo() minus the model-name lists
        and the ``vis_data`` / ``stage`` entries a note or result may add."""
        res = self._getAllInfo()
        for key in ("teacher_models", "student_models", "vis_data", "stage"):
            res.pop(key, None)
        return res

    @property
    def details(self):
        """Full decoded view of the task, exactly as returned by _getAllInfo()."""
        return self._getAllInfo()


class Algorithm(models.Model):
    """An algorithm identified by a unique name.

    Its per-algorithm settings are stored as ``AlgorithmField`` rows that
    reference it.
    """

    alg_name = models.CharField(max_length=128, unique=True)

    class Meta:
        db_table = 'algorithm'

    def __str__(self):
        return f'Algorithm: {self.alg_name}'


class AlgorithmField(models.Model):
    """One named setting (value plus free-text note) belonging to an Algorithm.

    ``(alg, field_name)`` is unique, so an algorithm has at most one row per
    field name.
    """

    # NOTE(review): related_name='alg' makes the reverse accessor
    # ``<Algorithm>.alg`` (the fields of an algorithm); kept for compatibility.
    alg = models.ForeignKey('task.Algorithm', on_delete=models.CASCADE, related_name='alg')
    field_name = models.CharField(max_length=128)
    field_value = models.CharField(max_length=128, blank=True)
    field_note = models.CharField(max_length=512, blank=True)

    class Meta:
        db_table = 'algorithm_field'
        unique_together = ['alg', 'field_name']
    
def serialize(obj):
    """Encode *obj* as pickle bytes for storage in a BinaryField."""
    # DEFAULT_PROTOCOL is exactly what pickle.dumps uses when no protocol is given.
    return pickle.dumps(obj, protocol=pickle.DEFAULT_PROTOCOL)

def unserialize(obj):
    """Decode pickle data produced by ``serialize``.

    *obj* may be ``bytes`` or any other bytes-like object (e.g. a
    ``memoryview``, as some database backends return for BinaryField values).

    SECURITY: unpickling can execute arbitrary code — only call this on blobs
    this application wrote itself, never on client-supplied data.
    """
    decoded = pickle.loads(obj)
    return decoded
